/**
 * Boost Software License - Version 1.0 - August 17th, 2003
 *
 * Permission is hereby granted, free of charge, to any person or organization
 * obtaining a copy of the software and accompanying documentation covered by
 * this license (the "Software") to use, reproduce, display, distribute,
 * execute, and transmit the Software, and to prepare derivative works of the
 * Software, and to permit third-parties to whom the Software is furnished to
 * do so, all subject to the following:
 *
 * The copyright notices in the Software and this entire statement, including
 * the above license grant, this restriction and the following disclaimer,
 * must be included in all copies of the Software, in whole or in part, and
 * all derivative works of the Software, unless such copies or derivative
 * works are solely in the form of machine-executable object code generated by
 * a source language processor.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
 * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
 * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 *
 * --------------------------------------------------------------------------
 * \file dnn/src/common/elemwise/erfinv.h
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *
 * This file has been modified by Megvii ("Megvii Modifications").
 * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
 *
 * --------------------------------------------------------------------------
 */

#if !__CUDACC__ && !__HIPCC__

#include <cmath>

#include "src/common/utils.h"

//  (C) Copyright John Maddock 2006.
//  Use, modification and distribution are subject to the
//  Boost Software License, Version 1.0. (See accompanying file
//  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

// Evaluates poly[0] + poly[1]*z + ... + poly[count-1]*z^(count-1) with
// Horner's rule, starting from the highest-order coefficient.
//   poly  - coefficients in ascending order of power (at least `count` items)
//   z     - evaluation point; its type U is also the accumulator type
//   count - number of coefficients; must be non-zero (checked by megdnn_assert)
// Every coefficient is converted to U before being accumulated.
template <class T_, class U>
inline U evaluate_polynomial(const T_* poly, U const& z, std::size_t count) {
    megdnn_assert(count > 0);
    std::size_t idx = count - 1;
    U acc = static_cast<U>(poly[idx]);
    while (idx != 0) {
        --idx;
        // Multiply and add are deliberately separate statements: fusing them
        // into one expression may allow some compilers (e.g. Clang with its
        // default -ffp-contract=on) to emit an FMA and change the rounding.
        acc *= z;
        acc += static_cast<U>(poly[idx]);
    }
    return acc;
}

// Overload for a built-in coefficient array: the number of coefficients is
// taken from the array extent N rather than being passed explicitly.
template <std::size_t N, class T, class V>
inline V evaluate_polynomial(const T (&a)[N], const V& val) {
    const T* const coeffs = &a[0];
    return evaluate_polynomial(coeffs, val, N);
}

//
// Shared core of erfinv() and erfcinv() in double precision, ported from
// Boost.Math's erf_inv_imp. The coefficients were fitted for types up to
// 80-bit long double, which also covers double.
//
// Parameters:
//   p - argument of the inverse erf. The callers in this file pass values in
//       [0, 1] (or NaN).
//   q - the complement 1 - p. The caller supplies it separately so that it
//       stays accurate when p is very close to 1, where forming 1 - p here
//       would cancel. The callers in this file pass q in (0, 1] (or NaN).
// Returns:
//   The x with erf(x) == p (equivalently erfc(x) == q). The callers apply
//   the sign, so this routine only handles the non-negative half.
//
// Three regimes, each with its own rational approximation:
//   p <= 0.5  : x = p(p+10)(Y + R(p))
//   q >= 0.25 : x = sqrt(-2 log q) / (Y + R(q - 0.25))
//   q <  0.25 : x = s(Y + R(s - B)) with s = sqrt(-log q), where the offset
//               B is one of {1.125, 3, 6, 18, 44}, chosen according to s.
//
// The arithmetic below mirrors Boost's expression forms (e.g. g * Y + g * r
// rather than g * (Y + r)); rewriting them changes the rounding.
//
inline double erfinv_imp(double p, double q) {
    using namespace std;

    double result = 0;

    if (p <= 0.5) {
        //
        // Evaluate inverse erf using the rational approximation:
        //
        // x = p(p+10)(Y+R(p))
        //
        // Where Y is a constant, and R(p) is optimised for a low
        // absolute error compared to |Y|.
        //
        // double: Max error found: 2.001849e-18
        // long double: Max error found: 1.017064e-20
        // Maximum Deviation Found (actual error term at infinite precision) 8.030e-21
        //
        static const float Y = 0.0891314744949340820313f;
        static const double P[] = {
                -0.000508781949658280665617, -0.00836874819741736770379,
                0.0334806625409744615033,    -0.0126926147662974029034,
                -0.0365637971411762664006,   0.0219878681111168899165,
                0.00822687874676915743155,   -0.00538772965071242932965};
        static const double Q[] = {
                1.0,
                -0.970005043303290640362,
                -1.56574558234175846809,
                1.56221558398423026363,
                0.662328840472002992063,
                -0.71228902341542847553,
                -0.0527396382340099713954,
                0.0795283687341571680018,
                -0.00233393759374190016776,
                0.000886216390456424707504};
        double g = p * (p + 10);
        double r = evaluate_polynomial(P, p) / evaluate_polynomial(Q, p);
        result = g * Y + g * r;
    } else if (q >= 0.25) {
        //
        // Rational approximation for 0.5 > q >= 0.25
        //
        // x = sqrt(-2*log(q)) / (Y + R(q))
        //
        // Where Y is a constant, and R(q) is optimised for a low
        // absolute error compared to Y.
        //
        // double : Max error found: 7.403372e-17
        // long double : Max error found: 6.084616e-20
        // Maximum Deviation Found (error term) 4.811e-20
        //
        static const float Y = 2.249481201171875f;
        static const double P[] = {-0.202433508355938759655, 0.105264680699391713268,
                                   8.37050328343119927838,   17.6447298408374015486,
                                   -18.8510648058714251895,  -44.6382324441786960818,
                                   17.445385985570866523,    21.1294655448340526258,
                                   -3.67192254707729348546};
        static const double Q[] = {
                1.0,
                6.24264124854247537712,
                3.9713437953343869095,
                -28.6608180499800029974,
                -20.1432634680485188801,
                48.5609213108739935468,
                10.8268667355460159008,
                -22.6436933413139721736,
                1.72114765761200282724};
        double g = sqrt(-2 * log(q));
        // R is expanded about q = 0.25, the lower edge of this region.
        double xs = q - 0.25f;
        double r = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
        result = g / (Y + r);
    } else {
        //
        // For q < 0.25 we have a series of rational approximations all
        // of the general form:
        //
        // let: x = sqrt(-log(q))
        //
        // Then the result is given by:
        //
        // x(Y+R(x-B))
        //
        // where Y is a constant, B is the lowest value of x for which
        // the approximation is valid, and R(x-B) is optimised for a low
        // absolute error compared to Y.
        //
        // Note that almost all code will really go through the first
        // or maybe second approximation.  After that we're dealing with very
        // small input values indeed: 80 and 128 bit long double's go all the
        // way down to ~ 1e-5000 so the "tail" is rather long...
        //
        // NOTE(review): for double, the smallest positive q is the denormal
        // ~4.9e-324, which gives x = sqrt(-log q) < ~27.3. The final x >= 44
        // branch below is therefore unreachable with double arguments (only a
        // NaN q reaches it); it is kept to mirror the Boost source.
        //
        double x = sqrt(-log(q));
        if (x < 3) {
            // Max error found: 1.089051e-20
            static const float Y = 0.807220458984375f;
            static const double P[] = {
                    -0.131102781679951906451,    -0.163794047193317060787,
                    0.117030156341995252019,     0.387079738972604337464,
                    0.337785538912035898924,     0.142869534408157156766,
                    0.0290157910005329060432,    0.00214558995388805277169,
                    -0.679465575181126350155e-6, 0.285225331782217055858e-7,
                    -0.681149956853776992068e-9};
            static const double Q[] = {
                    1.0,
                    3.46625407242567245975,
                    5.38168345707006855425,
                    4.77846592945843778382,
                    2.59301921623620271374,
                    0.848854343457902036425,
                    0.152264338295331783612,
                    0.01105924229346489121};
            double xs = x - 1.125f;
            double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
            result = Y * x + R * x;
        } else if (x < 6) {
            // Max error found: 8.389174e-21
            static const float Y = 0.93995571136474609375f;
            static const double P[] = {
                    -0.0350353787183177984712,  -0.00222426529213447927281,
                    0.0185573306514231072324,   0.00950804701325919603619,
                    0.00187123492819559223345,  0.000157544617424960554631,
                    0.460469890584317994083e-5, -0.230404776911882601748e-9,
                    0.266339227425782031962e-11};
            static const double Q[] = {
                    1.0,
                    1.3653349817554063097,
                    0.762059164553623404043,
                    0.220091105764131249824,
                    0.0341589143670947727934,
                    0.00263861676657015992959,
                    0.764675292302794483503e-4};
            double xs = x - 3;
            double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
            result = Y * x + R * x;
        } else if (x < 18) {
            // Max error found: 1.481312e-19
            static const float Y = 0.98362827301025390625f;
            static const double P[] = {
                    -0.0167431005076633737133,  -0.00112951438745580278863,
                    0.00105628862152492910091,  0.000209386317487588078668,
                    0.149624783758342370182e-4, 0.449696789927706453732e-6,
                    0.462596163522878599135e-8, -0.281128735628831791805e-13,
                    0.99055709973310326855e-16};
            static const double Q[] = {
                    1.0,
                    0.591429344886417493481,
                    0.138151865749083321638,
                    0.0160746087093676504695,
                    0.000964011807005165528527,
                    0.275335474764726041141e-4,
                    0.282243172016108031869e-6};
            double xs = x - 6;
            double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
            result = Y * x + R * x;
        } else if (x < 44) {
            // Max error found: 5.697761e-20
            static const float Y = 0.99714565277099609375f;
            static const double P[] = {
                    -0.0024978212791898131227,   -0.779190719229053954292e-5,
                    0.254723037413027451751e-4,  0.162397777342510920873e-5,
                    0.396341011304801168516e-7,  0.411632831190944208473e-9,
                    0.145596286718675035587e-11, -0.116765012397184275695e-17};
            static const double Q[] = {
                    1.0,
                    0.207123112214422517181,
                    0.0169410838120975906478,
                    0.000690538265622684595676,
                    0.145007359818232637924e-4,
                    0.144437756628144157666e-6,
                    0.509761276599778486139e-9};
            double xs = x - 18;
            double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
            result = Y * x + R * x;
        } else {
            // Max error found: 1.279746e-20
            static const float Y = 0.99941349029541015625f;
            static const double P[] = {
                    -0.000539042911019078575891, -0.28398759004727721098e-6,
                    0.899465114892291446442e-6,  0.229345859265920864296e-7,
                    0.225561444863500149219e-9,  0.947846627503022684216e-12,
                    0.135880130108924861008e-14, -0.348890393399948882918e-21};
            static const double Q[] = {
                    1.0,
                    0.0845746234001899436914,
                    0.00282092984726264681981,
                    0.468292921940894236786e-4,
                    0.399968812193862100054e-6,
                    0.161809290887904476097e-8,
                    0.231558608310259605225e-11};
            double xs = x - 44;
            double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs);
            result = Y * x + R * x;
        }
    }
    return result;
}

// Inverse complementary error function: returns the x for which
// erfc(x) == z.
//   - z outside [0, 2] returns NaN;
//   - z == 0 returns +inf and z == 2 returns -inf;
//   - otherwise the value comes from erfinv_imp, negated for z > 1.
// A NaN argument fails every comparison and propagates to a NaN result.
inline double erfcinv(double z) {
    // Domain errors first, then the two poles at the ends of the domain.
    if (z < 0 || z > 2)
        return NAN;
    if (z == 0)
        return INFINITY;
    if (z == 2)
        return -INFINITY;

    // Reflection formula erfc(-x) = 2 - erfc(x): for z in (1, 2) solve for
    // 2 - z instead and flip the sign, so that erfinv_imp sees q in (0, 1].
    const bool reflected = z > 1;
    const double q = reflected ? 2 - z : z;
    const double p = 1 - q;
    const double sign = reflected ? -1.0 : 1.0;

    return sign * erfinv_imp(p, q);
}

// Inverse error function: returns the x for which erf(x) == z.
//   - z outside [-1, 1] returns NaN;
//   - z == 1 returns +inf and z == -1 returns -inf;
//   - z == +/-0 returns z itself, so the sign of zero is preserved
//     (erfinv is an odd function, hence erfinv(-0.0) is -0.0);
//   - otherwise the value comes from erfinv_imp on |z|, negated for z < 0.
// A NaN argument fails every comparison and propagates to a NaN result.
inline double erfinv(double z) {
    //
    // Begin by testing for domain errors, and other special cases:
    //
    if ((z < -1) || (z > 1))
        return NAN;
    if (z == 1)
        return INFINITY;
    if (z == -1)
        return -INFINITY;
    // Return the argument rather than a literal 0 so that -0.0 keeps its
    // sign instead of being turned into +0.0.
    if (z == 0)
        return z;
    //
    // Normalise the input, so it's in the range [0,1], we will
    // negate the result if z is outside that range.  This is a simple
    // application of the erf reflection formula: erf(-z) = -erf(z)
    //
    double p, q, s;
    if (z < 0) {
        p = -z;
        q = 1 - p;
        s = -1;
    } else {
        p = z;
        q = 1 - z;
        s = 1;
    }

    //
    // And get the result, negating where required:
    //
    return s * erfinv_imp(p, q);
}

// Single-precision inverse erfc: widens z to double, evaluates erfcinv and
// rounds the result back to float.
inline float erfcinvf(float z) {
    const double wide = erfcinv(static_cast<double>(z));
    return static_cast<float>(wide);
}

// Single-precision inverse erf: widens z to double, evaluates erfinv and
// rounds the result back to float.
inline float erfinvf(float z) {
    const double wide = erfinv(static_cast<double>(z));
    return static_cast<float>(wide);
}

#endif  // !__CUDACC__ && !__HIPCC__

// vim: ft=cpp syntax=cpp.doxygen
